import pandas as pd
from datasets import load_dataset
from transformers import pipeline, logging
import torch
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score, precision_recall_fscore_support, classification_report
from sklearn.utils import resample
import numpy as np
import matplotlib.pyplot as plt
import warnings
# Quiet down third-party noise so the progress prints stay readable:
# ignore UserWarning / FutureWarning, opt in to pandas' future
# no-silent-downcasting behaviour, silence chained-assignment warnings,
# and only let transformers log errors.
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=FutureWarning)
pd.options.future.no_silent_downcasting = True
pd.set_option('mode.chained_assignment', None)
logging.set_verbosity_error()

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device():
    """Return the requested compute device name as a lower-case string.

    Resolution order: ``--device`` on the command line, then the ``DEVICE``
    environment variable, then ``"cuda"``. An empty value at either level is
    treated as unset. Unknown command-line arguments are ignored
    (``parse_known_args``), so the script still runs when launched with
    extra flags.

    Returns:
        str: the device name, stripped and lower-cased (e.g. ``"cpu"``,
        ``"mps"``, ``"cuda"``). This function does not validate it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    args, _ = parser.parse_known_args()  # ignore unknown args so script still works normally
    # Priority: CLI argument > DEVICE environment variable > default="cuda".
    # `or` chaining makes an empty DEVICE="" fall through to the default
    # instead of producing an empty (invalid) device name.
    requested = args.device or os.getenv("DEVICE") or "cuda"
    # Normalise so "CUDA" or " cpu " are accepted.
    return requested.strip().lower()
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch


def _resolve_torch_device(name):
    """Map a device name to a ``torch.device``.

    "cpu" is used as-is. "mps" and "cuda" are used only when the
    corresponding backend reports itself available; otherwise this falls
    back to CPU. Any other name raises ValueError.
    """
    if name == "cpu":
        return torch.device("cpu")
    if name == "mps":
        accelerator_ok = torch.backends.mps.is_available()
    elif name == "cuda":
        accelerator_ok = torch.cuda.is_available()
    else:
        raise ValueError(f"Unknown device {name}")
    return torch.device(name) if accelerator_ok else torch.device("cpu")


device = _resolve_torch_device(DEVICE)
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################
ds = load_dataset('mlburnham/polnli_hypothesis_stability')
df = ds['test'].to_pandas()


def _pair_docs(frame, hypothesis_col):
    """Build pipeline inputs pairing each row's premise with one hypothesis column.

    Returns one ``{'text': premise, 'text_pair': hypothesis}`` dict per row of
    *frame*, in row order. Zipping the two column Series avoids a per-row
    ``frame.loc[i, col]`` lookup.
    """
    return [{'text': premise, 'text_pair': hypothesis}
            for premise, hypothesis in zip(frame['premise'], frame[hypothesis_col])]


# dictionary pairs to pass through pipeline (original hypothesis + 3 rewordings)
hyp_dict = _pair_docs(df, 'hypothesis')
alt1_dict = _pair_docs(df, 'alt1')
alt2_dict = _pair_docs(df, 'alt2')
alt3_dict = _pair_docs(df, 'alt3')

def label_docs(model, docs_dict, batch_size = 16, device = 'cuda', max_length = 512):
    """
    Classify premise/hypothesis pairs with a text-classification pipeline.

    Args:
        model: model name or path passed to ``transformers.pipeline``.
        docs_dict: list of ``{'text': ..., 'text_pair': ...}`` dicts.
        batch_size: number of pairs per forward pass.
        device: device handed to the pipeline (defaults to ``'cuda'``).
        max_length: maximum tokenized length; longer pairs are truncated.

    Returns:
        list[str]: the top predicted label name for each input, in input
        order. Which strings appear depends on the model's label config
        (this script maps ``'entailment'`` to 0 and anything else to 1).
    """
    # Nothing to classify: skip building the pipeline (and loading the model).
    if isinstance(docs_dict, (list, tuple)) and not docs_dict:
        return []
    pipe = pipeline(task = 'text-classification', model = model,
                    batch_size = batch_size, device = device,
                    max_length = max_length, truncation = True)
    return [result['label'] for result in pipe(docs_dict)]
    
#######################
## Label All Hypotheses 
#######################
models = ["mlburnham/Political_DEBATE_DeBERTa_base_v1.1",
          "mlburnham/Political_DEBATE_DeBERTa_large_v1.1"]

# list of dictionaries, and the column suffix each one is stored under
doc_dicts = [hyp_dict, alt1_dict, alt2_dict, alt3_dict]
dict_names = ['hyp', 'alt1', 'alt2', 'alt3']

# All hypothesis variants go through a single label_docs call per model, so
# each model is loaded once instead of once per variant.
all_docs = [doc for doc_dict in doc_dicts for doc in doc_dict]

# loop through models; split each model's predictions back per variant
for modname, modname_short in zip(models, ['base', 'large']):
    # Pass the device resolved from --device / DEVICE above. Without it,
    # label_docs uses its 'cuda' default and ignores that setting, which
    # fails on machines without CUDA.
    res = label_docs(modname, all_docs, device=device)
    start = 0
    for doc_dict, dictname in zip(doc_dicts, dict_names):
        stop = start + len(doc_dict)
        colname = f"{modname_short}_{dictname}"
        # 0 = entailment, 1 = any other predicted label
        df[colname] = [0 if label == 'entailment' else 1 for label in res[start:stop]]
        print(modname, dictname, 'complete')
        start = stop

########################
## Majority Voting Label 
########################
def _majority_vote(frame, cols):
    """Return ``(majority_label, tie_flag)`` Series for each row of ``frame[cols]``.

    The majority label is the row's mode. On a tie, pandas lists the modes in
    ascending order and the first (smallest) one is kept, i.e. 0. It is cast
    back to int because ``DataFrame.mode(axis=1)`` pads rows with fewer modes
    with NaN, which can turn the result into floats (written as 0.0/1.0 in
    the CSV). The tie flag is 1 when a row has more than one mode, else 0.
    """
    modes = frame[cols].mode(axis=1)
    majority = modes.iloc[:, 0].astype(int)
    tied = (modes.notna().sum(axis=1) > 1).astype(int)
    return majority, tied


base_cols = ['base_hyp', 'base_alt1', 'base_alt2', 'base_alt3']
large_cols = ['large_hyp', 'large_alt1', 'large_alt2', 'large_alt3']
base_final, base_tie = _majority_vote(df, base_cols)
large_final, large_tie = _majority_vote(df, large_cols)

# Output column order: both majority labels, then both tie flags.
df['base_final'] = base_final
df['large_final'] = large_final
df['base_tiemode'] = base_tie
df['large_tiemode'] = large_tie

# to_csv does not create missing directories.
os.makedirs("./data", exist_ok=True)
df.to_csv("./data/alt_hypotheses.csv", index = False)